/*
 * Copyright (c) 2021 The red-star Project
 *
 * Licensed under the Apache License, version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at:
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.inyourcode.core.compile;

import com.inyourcode.core.util.StackTraceUtil;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.core.io.ClassPathResource;

import javax.tools.Diagnostic;
import javax.tools.DiagnosticCollector;
import javax.tools.FileObject;
import javax.tools.ForwardingJavaFileManager;
import javax.tools.JavaCompiler;
import javax.tools.JavaFileManager;
import javax.tools.JavaFileObject;
import javax.tools.SimpleJavaFileObject;
import javax.tools.ToolProvider;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.net.URI;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;

/**
 * @author JackLei
 */
/**
 * Compiles Java source code in memory with the JDK's system compiler and loads the result.
 *
 * <p>The source is handed to {@code javac} as an in-memory {@link JavaFileObject}. The generated
 * {@code .class} bytes are captured in memory by a custom file manager and defined by a new
 * {@link ClassLoader} created for each compilation. Nothing is written to disk.
 *
 * <p>Requires a JDK at runtime: when {@link ToolProvider#getSystemJavaCompiler()} returns
 * {@code null} (e.g. on a plain JRE), every compilation fails and returns {@code null}.
 *
 * @author JackLei
 */
public class DynamicCompiler {
    private static final Logger LOGGER = LoggerFactory.getLogger(DynamicCompiler.class);

    /** Buffer size used when reading a source file from the classpath. */
    private static final int READ_BUFFER_SIZE = 4096;

    /**
     * Reads a UTF-8 encoded {@code .java} file from the classpath and compiles it.
     *
     * @param className    binary name of the class to load after compilation, e.g. {@code com.foo.Bar}
     * @param javaFilePath classpath location of the source file, resolved by {@link ClassPathResource}
     * @return the loaded class, or {@code null} if the file cannot be read or compilation/loading fails
     */
    public static Class compileByJavaFile(String className, String javaFilePath) {
        ClassPathResource classPathResource = new ClassPathResource(javaFilePath);
        // getInputStream() also works when the resource is packed inside a jar;
        // getFile() only works for resources that live directly on the file system.
        try (InputStream in = classPathResource.getInputStream()) {
            String sourceCode = new String(readFully(in), StandardCharsets.UTF_8);
            return compileBySourceCode(className, sourceCode);
        } catch (IOException e) {
            LOGGER.error("Dynamic compilation Java file [{}] of class [{}] failed, exception:{}",
                    javaFilePath, className, StackTraceUtil.stackTrace(e));
        }
        return null;
    }

    /**
     * Compiles {@code sourceCode} in memory and loads {@code className} from the compiled output.
     *
     * <p>Each call defines the compiled classes in a new class loader whose parent is the system
     * class loader. Loading is parent-first: if a class with the same name is already visible to
     * the parent, that class is returned and a warning is logged.
     *
     * @param className  binary name of the class declared by {@code sourceCode}, e.g. {@code com.foo.Bar}
     * @param sourceCode full Java source of the compilation unit
     * @return the loaded class, or {@code null} if no system compiler is available, compilation
     *         reports errors, or the class cannot be loaded
     */
    public static Class compileBySourceCode(String className, String sourceCode) {
        JavaCompiler compiler = ToolProvider.getSystemJavaCompiler();
        if (compiler == null) {
            LOGGER.error("Dynamic compilation class[{}] failed, no system Java compiler available (running on a JRE instead of a JDK?)",
                    className);
            return null;
        }
        // Collect compiler errors/warnings so they can be logged instead of going to System.err.
        DiagnosticCollector<JavaFileObject> diagnostics = new DiagnosticCollector<JavaFileObject>();
        // Wrap the source string as a compilation unit javac can read.
        JavaFileObject compilationUnit = new JavaSourceFileObject(className, sourceCode);
        // The forwarding manager captures the generated .class bytes in memory. Closing it also
        // closes the underlying standard file manager and releases the resources it opened.
        try (MyForwardingJavaFileManager fileManager = new MyForwardingJavaFileManager(
                compiler.getStandardFileManager(diagnostics, null, null))) {
            JavaCompiler.CompilationTask compilationTask = compiler.getTask(
                    null, fileManager, diagnostics, null, null, Arrays.asList(compilationUnit));
            // Do not try to load anything after a failed compilation: parent-first delegation
            // could otherwise return an unrelated same-named class without any error.
            if (!Boolean.TRUE.equals(compilationTask.call())) {
                LOGGER.error("Dynamic compilation class[{}] failed, diagnostics:{}",
                        className, formatDiagnostics(diagnostics));
                return null;
            }
            // The compiled bytes are held by the file manager (see getJavaFileForOutput).
            CompiledClassLoader classLoader =
                    new CompiledClassLoader(fileManager.getGeneratedOutputFiles());
            Class clazz = classLoader.loadClass(className);
            if (clazz.getClassLoader() != classLoader) {
                // The parent already had a class with this name, so the compiled bytes were not used.
                LOGGER.warn("Dynamic compilation class[{}]: a class with the same name is already visible to class loader {}, the freshly compiled version is NOT used",
                        className, clazz.getClassLoader());
            }
            return clazz;
        } catch (ClassNotFoundException | IOException e) {
            LOGGER.error("Dynamic compilation class[{}] failed, exception:{}", className, StackTraceUtil.stackTrace(e));
        }
        return null;
    }

    /**
     * Reads all remaining bytes of {@code in}; the caller is responsible for closing the stream.
     */
    private static byte[] readFully(InputStream in) throws IOException {
        ByteArrayOutputStream out = new ByteArrayOutputStream();
        byte[] buffer = new byte[READ_BUFFER_SIZE];
        int read;
        while ((read = in.read(buffer)) != -1) {
            out.write(buffer, 0, read);
        }
        return out.toByteArray();
    }

    /**
     * Renders collected compiler diagnostics as one line each: kind, line number and message.
     */
    private static String formatDiagnostics(DiagnosticCollector<JavaFileObject> diagnostics) {
        StringBuilder sb = new StringBuilder();
        for (Diagnostic<? extends JavaFileObject> diagnostic : diagnostics.getDiagnostics()) {
            sb.append(System.lineSeparator())
                    .append(diagnostic.getKind())
                    .append(" line ").append(diagnostic.getLineNumber())
                    .append(": ").append(diagnostic.getMessage(null));
        }
        return sb.toString();
    }

    /**
     * In-memory {@code .java} source file whose content is a string.
     */
    private static class JavaSourceFileObject extends SimpleJavaFileObject {
        private final String code;

        /**
         * @param name binary class name; dots become path separators in the synthetic URI
         * @param code the Java source text
         */
        public JavaSourceFileObject(String name, String code) {
            super(URI.create("string:///" + name.replace('.', '/') + Kind.SOURCE.extension),
                    Kind.SOURCE);
            this.code = code;
        }

        @Override
        public CharSequence getCharContent(boolean ignoreEncodingErrors) throws IOException {
            return code;
        }
    }

    /**
     * In-memory output file that collects the bytes javac writes (normally a compiled {@code .class}).
     */
    private static class ClassFileObject extends SimpleJavaFileObject {
        private final ByteArrayOutputStream outputStream;
        private final String className;

        protected ClassFileObject(String className, Kind kind) {
            super(URI.create("mem:///" + className.replace('.', '/') + kind.extension), kind);
            this.className = className;
            outputStream = new ByteArrayOutputStream();
        }

        @Override
        public OutputStream openOutputStream() throws IOException {
            return outputStream;
        }

        /** @return a copy of the bytes written so far */
        public byte[] getBytes() {
            return outputStream.toByteArray();
        }

        /** @return the class name javac passed when requesting this output file */
        public String getClassName() {
            return className;
        }
    }

    /**
     * File manager that redirects every output file javac requests into an in-memory
     * {@link ClassFileObject} and remembers them, so the compiled bytes never touch the disk.
     * All other operations are forwarded to the wrapped file manager.
     */
    private static class MyForwardingJavaFileManager extends ForwardingJavaFileManager<JavaFileManager> {
        private final List<ClassFileObject> outputFiles;

        protected MyForwardingJavaFileManager(JavaFileManager fileManager) {
            super(fileManager);
            outputFiles = new ArrayList<ClassFileObject>();
        }

        @Override
        public JavaFileObject getJavaFileForOutput(Location location, String className, JavaFileObject.Kind kind, FileObject sibling) throws IOException {
            ClassFileObject file = new ClassFileObject(className, kind);
            outputFiles.add(file);
            return file;
        }

        /** @return the live list of output files created so far */
        public List<ClassFileObject> getGeneratedOutputFiles() {
            return outputFiles;
        }
    }

    /**
     * Class loader that defines classes from in-memory compiled bytes.
     *
     * <p>Uses the default parent-first delegation: {@link #findClass} is only consulted when the
     * parent (system) class loader cannot find the class.
     */
    private static class CompiledClassLoader extends ClassLoader {
        private final List<ClassFileObject> files;

        private CompiledClassLoader(List<ClassFileObject> files) {
            this.files = files;
        }

        @Override
        protected Class<?> findClass(String name) throws ClassNotFoundException {
            Iterator<ClassFileObject> itr = files.iterator();
            while (itr.hasNext()) {
                ClassFileObject file = itr.next();
                if (file.getClassName().equals(name)) {
                    // Each compiled file can only be defined once, so drop it from the list.
                    itr.remove();
                    byte[] bytes = file.getBytes();
                    return defineClass(name, bytes, 0, bytes.length);
                }
            }
            return super.findClass(name);
        }
    }
}
